package services

import (
	"context"
	"fmt"
	"strings"
	"sync"

	"github.com/docker/docker/client"
	"github.com/rancher/rke/docker"
	"github.com/rancher/rke/hosts"
	"github.com/rancher/rke/k8s"
	"github.com/rancher/rke/log"
	"github.com/rancher/rke/pki"
	"github.com/rancher/rke/util"
	v3 "github.com/rancher/types/apis/management.cattle.io/v3"
	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
	apierrors "k8s.io/apimachinery/pkg/api/errors"
	k8sutil "k8s.io/apimachinery/pkg/util/intstr"
	"k8s.io/client-go/kubernetes"
	"k8s.io/kubectl/pkg/drain"
)

const (
	unschedulableEtcdTaint    = "node-role.kubernetes.io/etcd=true:NoExecute"
	unschedulableControlTaint = "node-role.kubernetes.io/controlplane=true:NoSchedule"
)

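// RunWorkerPlane deploys the worker plane components (nginx-proxy, sidekick, kubelet and kube-proxy)
// on all hosts concurrently, processing up to WorkerThreads hosts at a time.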
func RunWorkerPlane(ctx context.Context, allHosts []*hosts.Host, localConnDialerFactory hosts.DialerFactory, prsMap map[string]v3.PrivateRegistry, workerNodePlanMap map[string]v3.RKEConfigNodePlan, certMap map[string]pki.CertificatePKI, updateWorkersOnly bool, alpineImage string) error {
	log.Infof(ctx, "[%s] Building up Worker Plane..", WorkerRole)
	var errgrp errgroup.Group

	hostsQueue := util.GetObjectQueue(allHosts)
	for w := 0; w < WorkerThreads; w++ {
		errgrp.Go(func() error {
			var errList []error
			for host := range hostsQueue {
				runHost := host.(*hosts.Host)
				err := doDeployWorkerPlaneHost(ctx, runHost, localConnDialerFactory, prsMap, workerNodePlanMap[runHost.Address].Processes, certMap, updateWorkersOnly, alpineImage)
				if err != nil {
					errList = append(errList, err)
				}
			}
			return util.ErrList(errList)
		})
	}

	if err := errgrp.Wait(); err != nil {
		return err
	}
	log.Infof(ctx, "[%s] Successfully started Worker Plane..", WorkerRole)
	return nil
}

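// UpgradeWorkerPlane upgrades worker components first on nodes that also hold the etcd/controlplane
// roles (one at a time), then on worker-only nodes (up to maxUnavailable at a time). When some
// worker-only hosts fail but fewer than maxUnavailable of them, it returns a non-empty message
// instead of an error.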
func UpgradeWorkerPlane(ctx context.Context, kubeClient *kubernetes.Clientset, multipleRolesHosts []*hosts.Host, workerOnlyHosts []*hosts.Host, inactiveHosts []*hosts.Host, localConnDialerFactory hosts.DialerFactory, prsMap map[string]v3.PrivateRegistry, workerNodePlanMap map[string]v3.RKEConfigNodePlan, certMap map[string]pki.CertificatePKI, updateWorkersOnly bool, alpineImage string, upgradeStrategy *v3.NodeUpgradeStrategy, newHosts map[string]bool) (string, error) {
	log.Infof(ctx, "[%s] Upgrading Worker Plane..", WorkerRole)
	var errMsgMaxUnavailableNotFailed string
	maxUnavailable, err := CalculateMaxUnavailable(upgradeStrategy.MaxUnavailable, len(workerOnlyHosts))
	if err != nil {
		return errMsgMaxUnavailableNotFailed, err
	}
	if maxUnavailable > WorkerThreads {
		/* Upgrading a large number of nodes in parallel spawns a large number of goroutines, which has previously led to errors about too many open sockets.
		Because of this, RKE switched to using worker pools; 50 worker threads have been sufficient to optimize rke up, upgrading at most 50 nodes in parallel.
		So the user-configurable maxUnavailable is respected only as long as it's less than WorkerThreads and is capped at that value. */
		maxUnavailable = WorkerThreads
		logrus.Infof("Setting maxUnavailable to %v to avoid issues related to upgrading a large number of nodes in parallel", WorkerThreads)
	}

	maxUnavailable -= len(inactiveHosts)
	if maxUnavailable <= 0 {
		// inactive hosts have used up the entire unavailability budget; fall back to upgrading one
		// node at a time instead of skipping the upgrade entirely (or passing a negative buffer size further down)
		maxUnavailable = 1
	}

	updateNewHostsList(kubeClient, append(multipleRolesHosts, workerOnlyHosts...), newHosts)
	log.Infof(ctx, "First checking and processing worker components for upgrades on nodes with etcd/controlplane roles one at a time")
	multipleRolesHostsFailedToUpgrade, err := processWorkerPlaneForUpgrade(ctx, kubeClient, multipleRolesHosts, localConnDialerFactory, prsMap, workerNodePlanMap, certMap, updateWorkersOnly, alpineImage, 1, upgradeStrategy, newHosts)
	if err != nil {
		logrus.Errorf("Failed to upgrade hosts: %v with error %v", strings.Join(multipleRolesHostsFailedToUpgrade, ","), err)
		return errMsgMaxUnavailableNotFailed, err
	}

	log.Infof(ctx, "Now checking and upgrading worker components on nodes with only worker role %v at a time", maxUnavailable)
	workerOnlyHostsFailedToUpgrade, err := processWorkerPlaneForUpgrade(ctx, kubeClient, workerOnlyHosts, localConnDialerFactory, prsMap, workerNodePlanMap, certMap, updateWorkersOnly, alpineImage, maxUnavailable, upgradeStrategy, newHosts)
	if err != nil {
		logrus.Errorf("Failed to upgrade hosts: %v with error %v", strings.Join(workerOnlyHostsFailedToUpgrade, ","), err)
		if len(workerOnlyHostsFailedToUpgrade) >= maxUnavailable {
			return errMsgMaxUnavailableNotFailed, err
		}
		errMsgMaxUnavailableNotFailed = fmt.Sprintf("Failed to upgrade hosts: %v with error %v", strings.Join(workerOnlyHostsFailedToUpgrade, ","), err)
	}

	log.Infof(ctx, "[%s] Successfully upgraded Worker Plane..", WorkerRole)
	return errMsgMaxUnavailableNotFailed, nil
}

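// CalculateMaxUnavailable resolves the configured max_unavailable value, given either as an absolute
// number or as a percentage of numHosts (rounded down, e.g. "10%" of 35 hosts yields 3), and always
// returns at least 1.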
func CalculateMaxUnavailable(maxUnavailableVal string, numHosts int) (int, error) {
	// if maxUnavailable is given in percent, round down
	maxUnavailableParsed := k8sutil.Parse(maxUnavailableVal)
	logrus.Debugf("Provided value for maxUnavailable: %v", maxUnavailableParsed)
	maxUnavailable, err := k8sutil.GetValueFromIntOrPercent(&maxUnavailableParsed, numHosts, false)
	if err != nil {
		logrus.Errorf("Unable to parse max_unavailable, should be a number or percentage of nodes, error: %v", err)
		return 0, err
	}
	if maxUnavailable == 0 {
		// In case there is only one node, or rounding down the maxUnavailable percentage led to 0, still upgrade one node at a time
		maxUnavailable = 1
	}
	logrus.Infof("%v worker nodes can be unavailable at a time", maxUnavailable)
	return maxUnavailable, nil
}

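// updateNewHostsList marks hosts that are present in the cluster state but not yet registered as
// Kubernetes nodes, so they get a plain deploy instead of a cordon/drain upgrade.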
func updateNewHostsList(kubeClient *kubernetes.Clientset, allHosts []*hosts.Host, newHosts map[string]bool) {
	for _, h := range allHosts {
		_, err := k8s.GetNode(kubeClient, h.HostnameOverride)
		if err != nil && apierrors.IsNotFound(err) {
			// this host could have been added to cluster state upon successful controlplane upgrade but isn't a node yet.
			newHosts[h.HostnameOverride] = true
		}
	}
}

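// processWorkerPlaneForUpgrade upgrades the worker components on the given hosts, at most
// maxUnavailable hosts at a time, and returns the hostnames of any hosts that failed to upgrade.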
func processWorkerPlaneForUpgrade(ctx context.Context, kubeClient *kubernetes.Clientset, allHosts []*hosts.Host, localConnDialerFactory hosts.DialerFactory, prsMap map[string]v3.PrivateRegistry, workerNodePlanMap map[string]v3.RKEConfigNodePlan, certMap map[string]pki.CertificatePKI, updateWorkersOnly bool, alpineImage string,
	maxUnavailable int, upgradeStrategy *v3.NodeUpgradeStrategy, newHosts map[string]bool) ([]string, error) {
	var errgrp errgroup.Group
	var drainHelper drain.Helper
	var failedHosts []string
	var hostsFailedToUpgrade = make(chan string, maxUnavailable)
	var hostsFailed sync.Map

	hostsQueue := util.GetObjectQueue(allHosts)
	if upgradeStrategy.Drain {
		drainHelper = getDrainHelper(kubeClient, *upgradeStrategy)
	}
	/* Each worker thread starts a goroutine that reads the hostsQueue channel in a for loop.
	Using the same number of worker threads as maxUnavailable ensures that at most maxUnavailable nodes are processed at a time.
	A node is done upgrading only after it is listed as ready and uncordoned. */
	for w := 0; w < maxUnavailable; w++ {
		errgrp.Go(func() error {
			var errList []error
			for host := range hostsQueue {
				runHost := host.(*hosts.Host)
				logrus.Infof("[workerplane] Processing host %v", runHost.HostnameOverride)
				if newHosts[runHost.HostnameOverride] {
					if err := doDeployWorkerPlaneHost(ctx, runHost, localConnDialerFactory, prsMap, workerNodePlanMap[runHost.Address].Processes, certMap, updateWorkersOnly, alpineImage); err != nil {
						errList = append(errList, err)
						hostsFailedToUpgrade <- runHost.HostnameOverride
						hostsFailed.Store(runHost.HostnameOverride, true)
						break
					}
					continue
				}
				nodes, err := getNodeListForUpgrade(kubeClient, &hostsFailed, newHosts, true)
				if err != nil {
					errList = append(errList, err)
				}
				var maxUnavailableHit bool
				for _, node := range nodes {
					// in case any previously added or not-yet-processed nodes have become unreachable during the upgrade
					if !k8s.IsNodeReady(node) {
						if len(hostsFailedToUpgrade) >= maxUnavailable {
							maxUnavailableHit = true
							break
						}
						hostsFailed.Store(node.Labels[k8s.HostnameLabel], true)
						hostsFailedToUpgrade <- node.Labels[k8s.HostnameLabel]
						errList = append(errList, fmt.Errorf("host %v not ready", node.Labels[k8s.HostnameLabel]))
					}
				}
				if maxUnavailableHit || len(hostsFailedToUpgrade) >= maxUnavailable {
					break
				}
				upgradable, err := isWorkerHostUpgradable(ctx, runHost, workerNodePlanMap[runHost.Address].Processes)
				if err != nil {
					errList = append(errList, err)
					hostsFailed.Store(runHost.HostnameOverride, true)
					hostsFailedToUpgrade <- runHost.HostnameOverride
					break
				}
				if !upgradable {
					logrus.Infof("[workerplane] Upgrade not required for worker components of host %v", runHost.HostnameOverride)
					continue
				}
				if err := upgradeWorkerHost(ctx, kubeClient, runHost, upgradeStrategy.Drain, drainHelper, localConnDialerFactory, prsMap, workerNodePlanMap, certMap, updateWorkersOnly, alpineImage); err != nil {
					errList = append(errList, err)
					hostsFailed.Store(runHost.HostnameOverride, true)
					hostsFailedToUpgrade <- runHost.HostnameOverride
					break
				}
			}
			return util.ErrList(errList)
		})
	}

	err := errgrp.Wait()
	close(hostsFailedToUpgrade)
	if err != nil {
		for host := range hostsFailedToUpgrade {
			failedHosts = append(failedHosts, host)
		}
	}
	return failedHosts, err
}

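// upgradeWorkerHost waits for the node to be ready, cordons (and optionally drains) it, redeploys
// its worker plane components, waits for it to become ready again, and finally uncordons it.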
func upgradeWorkerHost(ctx context.Context, kubeClient *kubernetes.Clientset, runHost *hosts.Host, drainFlag bool, drainHelper drain.Helper,
	localConnDialerFactory hosts.DialerFactory, prsMap map[string]v3.PrivateRegistry, workerNodePlanMap map[string]v3.RKEConfigNodePlan, certMap map[string]pki.CertificatePKI, updateWorkersOnly bool,
	alpineImage string) error {
	if err := checkNodeReady(kubeClient, runHost, WorkerRole); err != nil {
		return err
	}
	// cordon and drain
	if err := cordonAndDrainNode(kubeClient, runHost, drainFlag, drainHelper, WorkerRole); err != nil {
		return err
	}
	logrus.Debugf("[workerplane] upgrading host %v", runHost.HostnameOverride)
	if err := doDeployWorkerPlaneHost(ctx, runHost, localConnDialerFactory, prsMap, workerNodePlanMap[runHost.Address].Processes, certMap, updateWorkersOnly, alpineImage); err != nil {
		return err
	}
	// consider the upgrade done only once the kube client lists the node as ready again
	if err := checkNodeReady(kubeClient, runHost, WorkerRole); err != nil {
		return err
	}
	// uncordon node
	if err := k8s.CordonUncordon(kubeClient, runHost.HostnameOverride, false); err != nil {
		return err
	}
	return nil
}

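// doDeployWorkerPlaneHost deploys the worker plane on a single host, adding unschedulable taints
// for dedicated etcd and controlplane hosts so regular workloads are not scheduled on them.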
func doDeployWorkerPlaneHost(ctx context.Context, host *hosts.Host, localConnDialerFactory hosts.DialerFactory, prsMap map[string]v3.PrivateRegistry, processMap map[string]v3.Process, certMap map[string]pki.CertificatePKI, updateWorkersOnly bool, alpineImage string) error {
	if updateWorkersOnly {
		if !host.UpdateWorker {
			return nil
		}
	}
	if !host.IsWorker {
		if host.IsEtcd {
			// Add unschedulable taint
			host.ToAddTaints = append(host.ToAddTaints, unschedulableEtcdTaint)
		}
		if host.IsControl {
			// Add unschedulable taint
			host.ToAddTaints = append(host.ToAddTaints, unschedulableControlTaint)
		}
	}
	return doDeployWorkerPlane(ctx, host, localConnDialerFactory, prsMap, processMap, certMap, alpineImage)
}

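// RemoveWorkerPlane removes the worker plane containers from the given hosts. Hosts that also hold
// the controlplane role are skipped unless force is set.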
func RemoveWorkerPlane(ctx context.Context, workerHosts []*hosts.Host, force bool) error {
	log.Infof(ctx, "[%s] Tearing down Worker Plane..", WorkerRole)
	var errgrp errgroup.Group
	hostsQueue := util.GetObjectQueue(workerHosts)
	for w := 0; w < WorkerThreads; w++ {
		errgrp.Go(func() error {
			var errList []error
			for host := range hostsQueue {
				runHost := host.(*hosts.Host)
				if runHost.IsControl && !force {
					log.Infof(ctx, "[%s] Host [%s] is already a controlplane host, nothing to do.", WorkerRole, runHost.Address)
					return nil
				}
				if err := removeKubelet(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
				if err := removeKubeproxy(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
				if err := removeNginxProxy(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
				if err := removeSidekick(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
			}
			return util.ErrList(errList)
		})
	}

	if err := errgrp.Wait(); err != nil {
		return err
	}
	log.Infof(ctx, "[%s] Successfully tore down Worker Plane..", WorkerRole)

	return nil
}

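// RestartWorkerPlane restarts the kubelet, kube-proxy and nginx-proxy containers on the given hosts.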
func RestartWorkerPlane(ctx context.Context, workerHosts []*hosts.Host) error {
	log.Infof(ctx, "[%s] Restarting Worker Plane..", WorkerRole)
	var errgrp errgroup.Group

	hostsQueue := util.GetObjectQueue(workerHosts)
	for w := 0; w < WorkerThreads; w++ {
		errgrp.Go(func() error {
			var errList []error
			for host := range hostsQueue {
				runHost := host.(*hosts.Host)
				if err := RestartKubelet(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
				if err := RestartKubeproxy(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
				if err := RestartNginxProxy(ctx, runHost); err != nil {
					errList = append(errList, err)
				}
			}
			return util.ErrList(errList)
		})
	}
	if err := errgrp.Wait(); err != nil {
		return err
	}
	log.Infof(ctx, "[%s] Successfully restarted Worker Plane..", WorkerRole)

	return nil
}

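// doDeployWorkerPlane starts the worker plane containers on a host: nginx-proxy (skipped on
// controlplane hosts), the sidekick, kubelet and kube-proxy.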
func doDeployWorkerPlane(ctx context.Context, host *hosts.Host,
	localConnDialerFactory hosts.DialerFactory,
	prsMap map[string]v3.PrivateRegistry, processMap map[string]v3.Process, certMap map[string]pki.CertificatePKI, alpineImage string) error {
	// run nginx proxy
	if !host.IsControl {
		if err := runNginxProxy(ctx, host, prsMap, processMap[NginxProxyContainerName], alpineImage); err != nil {
			return err
		}
	}
	// run sidekick
	if err := runSidekick(ctx, host, prsMap, processMap[SidekickContainerName]); err != nil {
		return err
	}
	// run kubelet
	if err := runKubelet(ctx, host, localConnDialerFactory, prsMap, processMap[KubeletContainerName], certMap, alpineImage); err != nil {
		return err
	}
	return runKubeproxy(ctx, host, localConnDialerFactory, prsMap, processMap[KubeproxyContainerName], alpineImage)
}

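// isWorkerHostUpgradable reports whether any worker plane container on the host needs to be
// recreated, either because its desired configuration has changed or because it is missing.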
func isWorkerHostUpgradable(ctx context.Context, host *hosts.Host, processMap map[string]v3.Process) (bool, error) {
	for _, service := range []string{NginxProxyContainerName, SidekickContainerName, KubeletContainerName, KubeproxyContainerName} {
		process := processMap[service]
		imageCfg, hostCfg, _ := GetProcessConfig(process, host)
		upgradable, err := docker.IsContainerUpgradable(ctx, host.DClient, imageCfg, hostCfg, service, host.Address, WorkerRole)
		if err != nil {
			if client.IsErrNotFound(err) {
				if service == NginxProxyContainerName && host.IsControl {
					// nginxProxy should not exist on control hosts, so no changes needed
					continue
				}
				// doDeployWorkerPlane should be called so this container gets recreated
				logrus.Debugf("[%s] Host %v is upgradable because %v needs to run", WorkerRole, host.HostnameOverride, service)
				return true, nil
			}
			return false, err
		}
		if upgradable {
			logrus.Debugf("[%s] Host %v is upgradable because %v has changed", WorkerRole, host.HostnameOverride, service)
			// the host is upgradable if even a single service is upgradable
			return true, nil
		}
	}
	logrus.Debugf("[%s] Host %v is not upgradable", WorkerRole, host.HostnameOverride)
	return false, nil
}
